# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS.  All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.

"""Unit tests for the eval_scores module.
"""

import os
import shutil
import tempfile
import unittest

import pydub

from . import data_access
from . import eval_scores
from . import eval_scores_factory
from . import signal_processing


class TestEvalScores(unittest.TestCase):
  """Unit tests for the eval_scores module.
  """

  def setUp(self):
    """Create a temporary output folder and two fake audio track files.

    The reference and tested tracks are each produced by
    SignalProcessingUtils.GenerateWhiteNoise() from a 1000 ms, 48 kHz silent
    template and saved as WAV files inside the temporary folder.
    """
    self._output_path = tempfile.mkdtemp()
    # Registered as a cleanup instead of being removed in tearDown(): if the
    # rest of setUp() raises, tearDown() is not run, but cleanup functions
    # are, so the temporary folder never leaks.
    self.addCleanup(shutil.rmtree, self._output_path)

    # Create fake reference and tested (i.e., APM output) audio track files.
    silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
    fake_reference_signal = (
        signal_processing.SignalProcessingUtils.GenerateWhiteNoise(silence))
    fake_tested_signal = (
        signal_processing.SignalProcessingUtils.GenerateWhiteNoise(silence))

    # Save fake audio tracks.
    self._fake_reference_signal_filepath = os.path.join(
        self._output_path, 'fake_ref.wav')
    signal_processing.SignalProcessingUtils.SaveWav(
        self._fake_reference_signal_filepath, fake_reference_signal)
    self._fake_tested_signal_filepath = os.path.join(
        self._output_path, 'fake_test.wav')
    signal_processing.SignalProcessingUtils.SaveWav(
        self._fake_tested_signal_filepath, fake_tested_signal)

  def testRegisteredClasses(self):
    """Run every registered evaluation score worker on the fake tracks.

    Workers named 'thd' and 'echo_metric' are skipped (tested separately).
    For every other worker, the score file it writes must load as a float.
    """
    # Evaluation score names to exclude (tested separately).
    excluded_names = {'thd', 'echo_metric'}

    # Preliminary check.
    self.assertTrue(os.path.exists(self._output_path))

    # Check that there is at least one registered evaluation score worker.
    registered_classes = eval_scores.EvaluationScore.REGISTERED_CLASSES
    self.assertIsInstance(registered_classes, dict)
    self.assertGreater(len(registered_classes), 0)

    # Instantiate the evaluation score workers factory with fake
    # dependencies. No echo metric tool binary is passed; 'echo_metric' is in
    # the excluded names above.
    eval_score_workers_factory = (
        eval_scores_factory.EvaluationScoreWorkerFactory(
            polqa_tool_bin_path=os.path.join(
                os.path.dirname(os.path.abspath(__file__)), 'fake_polqa'),
            echo_metric_tool_bin_path=None
        ))
    eval_score_workers_factory.SetScoreFilenamePrefix('scores-')

    # Try each registered evaluation score worker.
    for eval_score_name, eval_score_class in registered_classes.items():
      if eval_score_name in excluded_names:
        continue

      # Instantiate the evaluation score worker.
      eval_score_worker = eval_score_workers_factory.GetInstance(
          eval_score_class)

      # Set reference and test file paths, then run.
      eval_score_worker.SetReferenceSignalFilepath(
          self._fake_reference_signal_filepath)
      eval_score_worker.SetTestedSignalFilepath(
          self._fake_tested_signal_filepath)
      eval_score_worker.Run(self._output_path)

      # Check output.
      score = data_access.ScoreFile.Load(eval_score_worker.output_filepath)
      self.assertIsInstance(score, float)

  def testTotalHarmonicDistorsionScore(self):
    """Check that the THD score grows as a pure tone gets more distorted.

    Scores are computed for a pure tone, the same tone mixed with white noise,
    and white noise only; they must be strictly increasing in that order.
    """
    # Init.
    pure_tone_freq = 5000.0
    eval_score_worker = eval_scores.TotalHarmonicDistorsionScore('scores-')
    eval_score_worker.SetInputSignalMetadata({
        'signal': 'pure_tone',
        'frequency': pure_tone_freq,
        'test_data_gen_name': 'identity',
        'test_data_gen_config': 'default',
    })
    template = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)

    # Create 3 test signals: pure tone, pure tone + white noise, white noise
    # only.
    pure_tone = signal_processing.SignalProcessingUtils.GeneratePureTone(
        template, pure_tone_freq)
    white_noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
        template)
    noisy_tone = signal_processing.SignalProcessingUtils.MixSignals(
        pure_tone, white_noise)

    # Compute scores for increasingly distorted pure tone signals. The same
    # temporary WAV path is overwritten for each signal.
    tmp_filepath = os.path.join(self._output_path, 'tmp_thd.wav')
    scores = []
    for tested_signal in (pure_tone, noisy_tone, white_noise):
      # Save signal.
      signal_processing.SignalProcessingUtils.SaveWav(
          tmp_filepath, tested_signal)

      # Compute score.
      eval_score_worker.SetTestedSignalFilepath(tmp_filepath)
      eval_score_worker.Run(self._output_path)
      scores.append(eval_score_worker.score)

      # Remove output file to avoid caching, so the next Run() computes a
      # fresh score instead of reusing this one.
      os.remove(eval_score_worker.output_filepath)

    # Validate scores (lowest score with a pure tone). Each consecutive pair
    # is checked separately so a failure reports the offending values.
    for less_distorted, more_distorted in zip(scores, scores[1:]):
      self.assertLess(less_distorted, more_distorted)
